# Computes the error between the test-set targets y and the predictions y_pre.
from sklearn.metrics import mean_squared_error, mean_absolute_error  # 评价模型预测效果


class MeanError:
    """Regression error metrics between test-set targets and predictions.

    Both inputs are converted once to NumPy arrays at construction time; the
    metric methods then delegate to ``sklearn.metrics``.

    Attributes:
        y: Ground-truth values as a NumPy array (size-1 dims squeezed, at least 1-D).
        y_pre: Predicted values as a NumPy array (size-1 dims squeezed, at least 1-D).
    """

    def __init__(self, y, y_pre):
        """Store ``y`` and ``y_pre`` as NumPy arrays.

        Args:
            y: Ground-truth tensor (presumably a ``torch.Tensor``; any object
                providing ``detach``/``cpu``/``squeeze``/``numpy`` works).
            y_pre: Prediction tensor; after squeezing it should match ``y``
                in shape for the sklearn metrics to accept it.
        """
        super().__init__()
        self.y = self._to_numpy(y)
        self.y_pre = self._to_numpy(y_pre)

    @staticmethod
    def _to_numpy(tensor):
        """Convert a tensor to a NumPy array with size-1 dims removed, kept at least 1-D."""
        # detach() so tensors that require grad (e.g. raw model outputs) can be
        # converted; .numpy() raises on them otherwise. No-op when no grad.
        array = tensor.detach().cpu().squeeze().numpy()
        # squeeze() turns a single-sample tensor into a 0-d array, which sklearn
        # rejects as not being a valid collection; keep it 1-D instead.
        return array.reshape(1) if array.ndim == 0 else array

    def mean_sq_error(self):
        """Return the mean squared error between ``y`` and ``y_pre``."""
        return mean_squared_error(self.y, self.y_pre)

    def mean_ab_error(self):
        """Return the mean absolute error between ``y`` and ``y_pre``."""
        return mean_absolute_error(self.y, self.y_pre)